MNIST Geodesics - 5 digits

In this notebook, we use the first 5 MNIST digits to compare the different approaches we have to approximate the Riemannian distance:

  1. Riemannian length of the Euclidean interpolation
  2. Discrete Geodesic Algorithm from Shao et al. (2017)
  3. ODE from the Latent Space Oddity paper
  4. Shortest path in the latent graph
  5. Using the shortest path as an initialization for the discrete geodesic algorithm

Conclusion

Stochastic Riemannian length of shortest curves found:

  • Method 1 (Euclidean Interpolation): 161.9
  • Method 2 (Discrete Geodesic): 116.1
  • Method 3 (ODE): did not converge after 32 min and 1000+ nodes
  • Method 4 (Graph): 48.4
  • Method 5 (Graph as initialization for discrete): 43.2

Time per 100 geodesic approximations:

  • Method 1 (Euclidean Interpolation): 2.4s
  • Method 2 (Discrete Geodesic): 8min 12s
  • Method 4 (Graph): 6.24s
  • Method 5 (Graph as initialization for discrete): 2min 35s

Imports and setup of plotting library

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
from copy import deepcopy
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# Set up plotly
init_notebook_mode(connected=True)
layout = go.Layout(
    width=700,
    height=500,
    margin=go.Margin(l=60, r=60, b=40, t=20),
    showlegend=False
)
config={'showLink': False}

# Make results completely repeatable
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)

digit_classes = [0,1,2,3,4]
/Users/kilian/dev/tum/2018-mlic-kilian/venv/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

Create the VAE

following the implementation details in appendix D in the Latent Space Oddity paper.

In [2]:
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras.layers import Dense, Input, Lambda
from src.vae import VAE
from src.rbf import RBFLayer

# Implementation details from Appendix D
input_dim = 784
latent_dim = 2
l2_reg = tf.keras.regularizers.l2(1e-5)

# Create the encoder models
enc_input = Input((input_dim,))
enc_shared = Dense(64, activation='tanh', kernel_regularizer=l2_reg)
enc_mean = Sequential([
    enc_shared,
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(latent_dim, activation='linear', kernel_regularizer=l2_reg)
])
enc_var = Sequential([
    enc_shared,
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(latent_dim, activation='softplus', kernel_regularizer=l2_reg)
])
enc_mean = Model(enc_input, enc_mean(enc_input))
enc_var = Model(enc_input, enc_var(enc_input))

# Create the decoder models
dec_input = Input((latent_dim,))
dec_mean = Sequential([
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(64, activation='tanh', kernel_regularizer=l2_reg),
    Dense(input_dim, activation='sigmoid', kernel_regularizer=l2_reg)
])
dec_mean = Model(dec_input, dec_mean(dec_input))

# Build the RBF network
num_centers = 64
a = 1.0
rbf = RBFLayer([input_dim], num_centers)
dec_var = Model(dec_input, rbf(dec_input))

vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=1.)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.

Filter the digits from MNIST

In [3]:
from tensorflow.python.keras.datasets import mnist

# Train the VAE on MNIST digits
(x_train_all, y_train_all), _ = mnist.load_data()

# Filter the digit classes from the mnist data
x_train = []
y_train = []
for digit_class in digit_classes:
    for x, y in zip(x_train_all, y_train_all):
        if y == digit_class:
            x_train.append(x)
            y_train.append(y)
            
x_train = np.array(x_train).astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
y_train = np.array(y_train)

# Shuffle the data
p = np.random.permutation(len(x_train))
x_train = x_train[p]
y_train = y_train[p]

Train the VAE

without training the generator's variance network. This will be trained separately later.

In [4]:
history = vae.model.fit(x_train,
              epochs=100,
              batch_size=32,
              validation_split=0.1,
              verbose=0)

# Plot the losses
data = [go.Scatter(y=history.history['loss'], name='Train Loss'),
       go.Scatter(y=history.history['val_loss'], name='Validation Loss')]
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)

Visualize the latent space

In [5]:
# Display a 2D plot of the classes in the latent space
encoded_sampled, encoded_mean, encoded_var = vae.encoder.predict(x_train)

# Plot
scatter_data = []
colors = ['#825446', '#26A69A', '#F7B554', '#2D4366', '#84439A']
plot_points = []
for i_class, digit_class in enumerate(digit_classes):
    # Filter 500 points of this class from the training data
    class_hits = [y == digit_class for y in y_train]
    class_indices = np.arange(len(class_hits))[class_hits]
    class_indices = class_indices[:500]
    x_class = encoded_mean[class_indices]
    plot_points.extend(x_class)
    # Plot
    scatter_data.append(go.Scatter(
        x = x_class[:, 0],
        y = x_class[:, 1],
        mode = 'markers',
        marker = {'color': colors[i_class]},
        name = digit_class,
        hoverinfo = 'text',
        text = class_indices
    ))
    
iplot(go.Figure(data=scatter_data, layout=layout), config=config)
plot_points = np.array(plot_points)

Train the generator's variance network

For this, we first have to find the centers of the latent points.

In [6]:
from sklearn.cluster import KMeans

# Find the centers of the latent representations
kmeans_model = KMeans(n_clusters=num_centers, random_state=0)
kmeans_model = kmeans_model.fit(encoded_mean)
centers = kmeans_model.cluster_centers_

# Visualize the centers
center_plot = go.Scatter(
    x = centers[:, 0],
    y = centers[:, 1],
    mode = 'markers',
    marker = {'color': 'red'}
)
data = scatter_data + [center_plot] 
iplot(go.Figure(data=data, layout=layout), config=config)

Compute the bandwidths

In [7]:
# Cluster the latent representations
clustering = dict((c_i, []) for c_i in range(num_centers))
for z_i, c_i in zip(encoded_mean, kmeans_model.predict(encoded_mean)):
    clustering[c_i].append(z_i)
    
bandwidths = []
for c_i, cluster in clustering.items():
    if cluster:
        diffs = np.array(cluster) - centers[c_i]
        avg_dist = np.mean(np.linalg.norm(diffs, axis=1))
        bandwidth = 0.5 / (a * avg_dist)**2
    else:
        bandwidth = 0
    bandwidths.append(bandwidth)
bandwidths = np.array(bandwidths)

Train the variance network

In [8]:
# Train the RBF
vae.recompile_for_var_training()
rbf_kernel = rbf.get_weights()[0]
rbf.set_weights([rbf_kernel, centers, bandwidths])

history = vae.model.fit(x_train,
                        epochs=100,
                        batch_size=32,
                        validation_split=0.1,
                        verbose=0)

# Plot the losses
data = [go.Scatter(y=history.history['loss'],
                   name='Train Loss'),
        go.Scatter(y=history.history['val_loss'],
                   name='Validation Loss')]
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
In [9]:
from src.util import wrap_model_in_float64

# Get the mean and std predictors
_, mean_output, var_output = vae.decoder.output
sqrt_layer = Lambda(tf.sqrt)
dec_mean = Model(vae.decoder.input, mean_output)
dec_std = Model(vae.decoder.input, sqrt_layer(var_output))
dec_mean = wrap_model_in_float64(dec_mean)
dec_std = wrap_model_in_float64(dec_std)

session = tf.keras.backend.get_session()

Choose two latent points

for finding a geodesic.

In [10]:
z_start_index = 1586
z_end_index = 2051
z_start, z_end = plot_points[[z_start_index, z_end_index]] 

# Visualize the centers
task_plot = go.Scatter(
    x = [z_start[0], z_end[0]],
    y = [z_start[1], z_end[1]],
    mode = 'markers',
    marker = {'color': 'd32f2f'}
)
data = scatter_data + [task_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Plot the magnification factors

In [11]:
from src.plot import plot_magnification_factor

heatmap_z1 = np.linspace(-5, 5, 100)
heatmap_z2 = np.linspace(-5, 5, 100)
heatmap = plot_magnification_factor(session, 
                                    heatmap_z1,
                                    heatmap_z2, 
                                    dec_mean, 
                                    dec_std, 
                                    additional_data=scatter_data + [task_plot],
                                    layout=layout,
                                    log_scale=True)
Computing Magnification Factors: 100%|██████████| 500/500 [00:01<00:00, 252.30it/s]

Define the evaluation metric

Before we start comparing the geodesic approximations, we need to define the metric. For each curve, we take equidistant steps in the latent space in order to compute the Riemannian length using numerical integration. We also plot the curve velocity.

In [12]:
from src.util import get_length_op, get_lengths_op, interpolate

curve_ph = tf.placeholder(tf.float64, [None, 2])
length_op, _ = get_length_op(curve_ph, dec_mean, dec_std)
lengths_op = get_lengths_op(curve_ph, dec_mean, dec_std)
lengths_op = tf.squeeze(lengths_op)

def evaluate_curve(curve, num_nodes=200, with_velocity_plot=True, 
                   verbose=True):
    curve = interpolate(curve, num_nodes)
    lengths = session.run(lengths_op, feed_dict={curve_ph: curve})
    length = np.sum(lengths)
    if verbose:
        print('Curve length: ', length)
    
    if with_velocity_plot:
        plot_velocity(lengths)
        
    return length
    
def plot_velocity(lengths):
    num_nodes = len(lengths)
    velocities = lengths * (num_nodes - 1)
    trace = go.Scatter(
        x = np.linspace(0, 1, num_nodes),
        y = velocities
    )
    iplot(go.Figure(data=[trace], layout=go.Layout(
        width=700,
        height=100,
        margin=go.Margin(l=60, r=60, b=20, t=20),
        showlegend=False
    )), config=config)

Method 1 - Euclidean Interpolation

In [13]:
t_nodes = np.linspace(0, 1, 50)
euclidean_curve = z_start + np.outer(t_nodes, z_end - z_start)
In [14]:
evaluate_curve(euclidean_curve)
Curve length:  161.86052278263105
Out[14]:
161.86052278263105

Method 2 - Discrete Geodesics

In [15]:
%%time
from src.discrete import find_geodesic_discrete

discrete_curve, discrete_iterations = find_geodesic_discrete(
    session, 
    z_start, z_end, 
    dec_mean, 
    std_generator=dec_std,
    num_nodes=50,
    max_steps=400,
    learning_rate=0.01,
    log_every=50,
    save_every=30)

print('-' * 20)
Step 0, Length 164.186722, Energy 45698.287936, Max velocity ratio 93.277833
Step 50, Length 56.084127, Energy 2100.714517, Max velocity ratio 27.859792
Step 100, Length 46.432401, Energy 1637.027999, Max velocity ratio 45.786492
Step 150, Length 45.221722, Energy 1539.038904, Max velocity ratio 39.928444
Step 200, Length 44.321481, Energy 1483.737693, Max velocity ratio 44.767189
Step 250, Length 43.784687, Energy 1453.031043, Max velocity ratio 43.603208
Step 300, Length 43.527600, Energy 1427.870480, Max velocity ratio 35.241766
Step 350, Length 43.432894, Energy 1404.505522, Max velocity ratio 30.820729
Step 400, Length 43.407710, Energy 1377.627264, Max velocity ratio 26.279630
--------------------
CPU times: user 13.2 s, sys: 4.59 s, total: 17.8 s
Wall time: 6.58 s
In [16]:
evaluate_curve(discrete_curve)
Curve length:  116.07069534937813
Out[16]:
116.07069534937813

Like in 14-graph-geodesics-moons.ipynb, the discrete geodesic algorithm's length estimate is strongly biased due to jumps over regions of large Riemannian metrics. Therefore, the actual curve length, 116, is higher than the length it converges to (43.4).

Method 3 - ODE

The ODE often does not converge. If it does, it takes orders of magnitue longer than the discrete geodesic algorithm without reaching a better solution. This is because the ODE only finds local minima just like the discrete geodesic algorithm. Neither of those methods searches globally for a solution.

In [17]:
%%time
from src.geodesic import find_geodesic

ode_result, ode_iterations = find_geodesic(session, z_start, z_end, 
                                           dec_mean, std_generator=dec_std, 
                                           initial_nodes=20, max_nodes=1000,
                                           use_fun_jac=True)
print('-' * 20)
   Iteration    Max residual    Total nodes    Nodes added  
       1          1.03e+02          20             38       
       2          9.62e+01          58             113      
       3          1.04e+03          171            340      
       4          1.84e+03          511          (1020)     
Number of nodes is exceeded after iteration 4, maximum relative residual 1.84e+03.
--------------------
CPU times: user 52min 9s, sys: 25min 22s, total: 1h 17min 31s
Wall time: 32min 43s
In [18]:
from src.plot import plot_latent_curve_iterations
plot_latent_curve_iterations(ode_iterations[::10], [heatmap] + scatter_data, 
                             layout, step_size=10)
In [19]:
ode_curve = ode_result.sol(ode_result.x)[0:2].T
evaluate_curve(ode_curve)
Curve length:  5922281.465610336
Out[19]:
5922281.465610336

Method 4 - Graph

We use the 500 points per digit from the latent plots above and add three times as many random gaussian noise points. This gives a total of 10,000 points in the latent space, which we will use for our graph in the latent space

In [20]:
extensions = [plot_points + np.random.randn(*plot_points.shape) 
              for _ in range(3)]
graph_points = np.concatenate([plot_points] + extensions)
print(graph_points.shape)
(10000, 2)

Compute the Riemannian distances of neighboring points

To get the nearest neighbors of each point, we use the get_neighbors function from src.graph. It is explained and defined in 13-graph-geodesics.ipynb.

Given the get_neighbors function, compute the Riemannian distance between each point and each of its neighbors. We approximate the Riemannian distance with a single midpoint for integration: $\int_0^1 \left\| J_{\gamma_t} \dot{\gamma}_t \right\| \mathrm{d}t \approx \left\| J_{\gamma_t} \dot{\gamma}_t \right\|$

In [21]:
import networkx as nx
from tqdm import tqdm

from src.util import get_metric_op
from src.graph import get_neighbors

point_ph = tf.placeholder(tf.float64, [2])
metric_op = get_metric_op(point_ph, dec_mean, dec_std)

# Compute the distance between the kNNs in Euclidean space
k = 4
graph = nx.Graph()
for i_point, point in enumerate(graph_points):
    graph.add_node(i_point, pos=point)
    
for i_point, point in enumerate(tqdm(graph_points)):
    neighbor_indices = get_neighbors(i_point, graph_points, k)
    
    for i_neighbor in neighbor_indices:
        if graph.has_edge(i_neighbor, i_point): 
            continue
        
        neighbor = graph_points[i_neighbor]
        middle = point + 0.5 * (neighbor - point) 
        velocity = neighbor - point
        metric = session.run(metric_op, feed_dict={point_ph: middle})
        length = velocity.T.dot(metric).dot(velocity)
        length = np.sqrt(length)
        graph.add_edge(i_point, i_neighbor, weight=length) 
100%|██████████| 10000/10000 [00:55<00:00, 181.38it/s]

Visualize a subgraph

and the relative weight of the edges (Riemannian length divided by Euclidean length). Green means a low relative weight, red means a large relative weight.

In [22]:
from src.plot import plot_graph_with_edge_colors, plot_graph

x_range = [min(z_start[0], z_end[0]), max(z_start[0], z_end[0])]
y_range = [min(z_start[1], z_end[1]), max(z_start[1], z_end[1])]

subnodes = []
for node in graph.nodes():
    pos = graph.node[node]['pos']
    if (x_range[0] <= pos[0] <= x_range[1] and
        y_range[0] <= pos[1] <= y_range[1]):
        subnodes.append(node)

subgraph = graph.subgraph(subnodes)
graph_plot = plot_graph_with_edge_colors(graph, layout=layout,
                                         additional_data=[task_plot])

Compute the shortest path

between the two points from above.

In [23]:
%%time
from networkx.algorithms.shortest_paths.generic import shortest_path
path = shortest_path(graph, z_start_index, z_end_index, weight='weight') #2
length = 0
for source, sink in zip(path[:-1], path[1:]):
    length += graph[source][sink]['weight']
print('Path length:', length)
print('-' * 20)
Path length: 48.84097456114202
--------------------
CPU times: user 116 ms, sys: 2.6 ms, total: 119 ms
Wall time: 118 ms

Visualize the shortest path

In [24]:
from src.plot import plot_graph

# Construct a subgraph from the path
path_graph = nx.Graph()
for point in path:
    path_graph.add_node(point, pos=graph_points[point])
for source, sink in zip(path[:-1], path[1:]):
    weight = graph[source][sink]['weight']
    path_graph.add_edge(source, sink, weight=weight) 

_ = plot_graph(path_graph, layout=layout, edge_color='#00DD00', 
               node_color='#00DD00', additional_data=[heatmap] + scatter_data)

Measure the actual curve length

Since we only computed the Riemannian distance for each edge using a single midpoint, the graph length is not exactly correct. It is not as strongly biased as the discrete geodesic algorithm's length estimate, but we should measure it as well with the interpolate function for a fair comparison.

In [25]:
graph_curve = graph_points[path]
evaluate_curve(graph_curve)
Curve length:  48.396277542785995
Out[25]:
48.396277542785995

Plot the shortest path and the discrete solution next to each other

In [26]:
# Plot the graph curve
graph_curve_plot = go.Scatter(
    x=graph_curve[:, 0],
    y=graph_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': '#3CA64D'}
)
# Plot the discrete curve
discrete_curve_plot = go.Scatter(
    x=discrete_curve[:, 0],
    y=discrete_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': '#d32f2f'}
)
data = [heatmap] + scatter_data + [graph_curve_plot, 
        discrete_curve_plot, task_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Method 5 - Graph as init to discrete algorithm

In [27]:
graphref_curve, _ = find_geodesic_discrete(
    session, 
    z_start, z_end, 
    dec_mean, 
    std_generator=dec_std,
    num_nodes=50,
    max_steps=50,
    learning_rate=0.01,
    log_every=10,
    save_every=30,
    curve_init=graph_curve)
Step 0, Length 46.497732, Energy 1542.773158, Max velocity ratio 15.636960
Step 10, Length 42.500477, Energy 1090.838101, Max velocity ratio 6.449714
Step 20, Length 42.133813, Energy 1020.233917, Max velocity ratio 4.056094
Step 30, Length 42.085070, Energy 990.533737, Max velocity ratio 3.537869
Step 40, Length 42.042742, Energy 981.345241, Max velocity ratio 3.048412
Step 50, Length 41.851993, Energy 971.681278, Max velocity ratio 2.925967
In [28]:
evaluate_curve(graphref_curve)
Curve length:  43.24855327805831
Out[28]:
43.24855327805831

Visualize the refined graph solution

In [29]:
plot_latent_curve_iterations([graphref_curve], [heatmap] + scatter_data, layout)

Multiple Points Benchmark

Measure the runtime of each approach on 100 random pairs of points. We don't use the ODE here, since it takes orders of magnitude longer than the discrete geodesic algorithm without giving better geodesic approximations.

In [30]:
z_starts = np.random.choice(len(plot_points), 100)
z_ends = np.random.choice(len(plot_points), 100)

Method 1 - Euclidean Interpolation

In [31]:
def test_euclidean(z_starts, z_ends):
    curve_ph = tf.placeholder(tf.float64, [None, 2])
    length_op = get_length_op(curve_ph, dec_mean, dec_std)
    curves = []
    lengths = []

    for z_start, z_end in zip(plot_points[z_starts], plot_points[z_ends]):
        t_nodes = np.linspace(0, 1, 20)
        curve = z_start + np.outer(t_nodes, z_end - z_start)
        length, _ = session.run(length_op, feed_dict={curve_ph: curve})
        curves.append(curve)
        lengths.append(length)
    return curves, lengths

Method 2 - Discrete Geodesics

In [32]:
from src.discrete import find_geodesics_discrete

def test_discrete(z_starts, z_ends):
    return find_geodesics_discrete(
        session, 
        plot_points[z_starts], plot_points[z_ends], 
        dec_mean, 
        std_generator=dec_std,
        num_nodes=50,
        max_steps=400,
        learning_rate=0.01)

Method 4 - Graph

In [33]:
def test_graph(z_starts, z_ends):
    curves = []
    lengths = []
    for z_start, z_end in zip(z_starts, z_ends):
        path = shortest_path(graph, z_start, z_end, weight='weight')
        curve = graph_points[path]
        length = 0
        for source, sink in zip(path[:-1], path[1:]):
            length += graph[source][sink]['weight']
        curves.append(curve)
        lengths.append(length)
    return curves, lengths

Method 5 - Graph with refinement

In [34]:
def test_graph_refinement(z_starts, z_ends):
    curve_inits = []
    for z_start, z_end in zip(z_starts, z_ends):
        path = shortest_path(graph, z_start, z_end, weight='weight')
        curve = graph_points[path]
        curve_inits.append(curve)
    return find_geodesics_discrete(
        session, 
        encoded_mean[z_starts], encoded_mean[z_ends], 
        dec_mean, 
        std_generator=dec_std,
        num_nodes=50,
        max_steps=20,
        learning_rate=0.01,
        curve_inits=curve_inits)

Measure the time per geodesic

In [35]:
%%time
eucl_curves, eucl_est_lengths = test_euclidean(z_starts, z_ends)
print('-' * 20)
--------------------
CPU times: user 2.66 s, sys: 156 ms, total: 2.82 s
Wall time: 2.4 s
In [36]:
%%time
discrete_curves, discrete_est_lengths = test_discrete(z_starts, z_ends)
print('-' * 20)
--------------------
CPU times: user 19min 38s, sys: 7min 59s, total: 27min 38s
Wall time: 8min 12s
In [37]:
%%time
graph_curves, graph_est_lengths = test_graph(z_starts, z_ends)
print('-' * 20)
--------------------
CPU times: user 6.22 s, sys: 19.4 ms, total: 6.24 s
Wall time: 6.24 s
In [38]:
%%time
graphref_curves, graphref_est_lengths = test_graph_refinement(z_starts, z_ends)
print('-' * 20)
--------------------
CPU times: user 3min, sys: 24.9 s, total: 3min 25s
Wall time: 2min 35s

Measure the actual lengths

We see that the discrete geodesic algorithm gives strongly biased length estimates.

In [39]:
def evaluate_curves(curves, estimated_lengths, num_nodes=200):
    lengths = []
    estimation_errors = []
    for curve, estimated_length in zip(curves, estimated_lengths):
        curve = interpolate(curve, num_nodes)
        length = session.run(length_op, feed_dict={curve_ph: curve})
        lengths.append(length)
        estimation_errors.append(estimated_length - length)
        
    print('Estimation error mean: ', np.mean(estimation_errors))
    print('Estimation error std: ', np.std(estimation_errors))
    return lengths
In [40]:
eucl_lengths = evaluate_curves(eucl_curves, eucl_est_lengths)
Estimation error mean:  -0.18765667507361627
Estimation error std:  4.403350785167673
In [41]:
discrete_lengths= evaluate_curves(discrete_curves, discrete_est_lengths)
Estimation error mean:  -10.967965472835845
Estimation error std:  21.445188335827066
In [42]:
graph_lengths = evaluate_curves(graph_curves, graph_est_lengths)
Estimation error mean:  0.028590551557503677
Estimation error std:  0.5364480229271307
In [43]:
graphref_lengths = evaluate_curves(graphref_curves, graphref_est_lengths)
Estimation error mean:  -0.9519951586949094
Estimation error std:  1.5001728902702054

Plot the lengths

In [44]:
eucl_trace = go.Scatter(
    x = graph_lengths,
    y = np.array(eucl_lengths),
    mode = 'markers',
    marker = {'size': 8, 'symbol': 'x', 'color': 'orange'},
    name = 'Euclidean'
)
graph_trace = go.Scatter(
    x = graph_lengths,
    y = np.array(graph_lengths),
    mode = 'markers',
    marker = {'size': 8, 'symbol': 'x', 'color': '#3CA8FF'},
    name = 'Graph           '
)
discrete_trace = go.Scatter(
    x = graph_lengths,
    y = np.array(discrete_lengths),
    mode = 'markers',
    marker = {'size': 8, 'symbol': 'x', 'color': '#d32f2f'},
    name = 'Discrete'
)
data = [eucl_trace, graph_trace, discrete_trace]
_layout = go.Layout(
    width=800,
    height=600,
    margin=go.Margin(l=60, r=60, b=40, t=20),
    xaxis={
        'title': 'Length of graph solution',
        'titlefont': {'size': 18}
    },
    yaxis={
        'title': 'Stochastic Riemannian length',
        'titlefont': {'size': 18}
    },
    legend={
        'font': {'size': 18}
    }
)
iplot(go.Figure(data=data, layout=_layout), config=config)
In [45]:
graphref_trace = go.Scatter(
    x = graph_lengths,
    y = np.array(graphref_lengths),
    mode = 'markers',
    marker = {'size': 8, 'symbol': 'x', 'color': '#2D4366'},
    name = 'Graph Refinement'
)
data = [eucl_trace, graph_trace, discrete_trace, graphref_trace]
iplot(go.Figure(data=data, layout=_layout), config=config)

Conclusion

Stochastic Riemannian length of shortest curves found:

  • Method 1 (Euclidean Interpolation): 161.9
  • Method 2 (Discrete Geodesic): 116.1
  • Method 3 (ODE): did not converge after 32 min and 1000+ nodes
  • Method 4 (Graph): 48.4
  • Method 5 (Graph as initialization for discrete): 43.2

Time per 100 geodesic approximations:

  • Method 1 (Euclidean Interpolation): 2.4s
  • Method 2 (Discrete Geodesic): 8min 12s
  • Method 4 (Graph): 6.24s
  • Method 5 (Graph as initialization for discrete): 2min 35s